ScatterNd
依据指定的索引 indices,将更新值 updates 累加到输出张量 output 的对应位置。
\[\mathrm{output}[\mathrm{indices}_i] = \mathrm{output}[\mathrm{indices}_i] + \mathrm{updates}_i\]
- 输入:
output - 输出张量的起始地址(计算前作为基础值,计算后存储结果)。
output_shape - 输出张量的形状数组。
output_ndim - 输出张量的维度数。
indices - 索引数据地址,其形状通常为 (num_slices, indices_depth)。
indices_shape - 索引张量的形状数组。
indices_ndim - 索引张量的维度数。
updates - 更新数据地址。
core_mask(int, 可选) - 核掩码(仅适用于共享存储版本)。
- 输出:
output - 计算结果地址。
- 支持平台:
FT78NE、MT7004
备注
FT78NE 支持 int8, int16, int32, fp32, fp64, cplx64, cplx128
MT7004 支持 fp16, fp32, int16, int32, cplx64
该算子在多核实现中由于涉及随机访存写,通常直接在 DDR 空间操作。
张量维度最大支持 8 维。
共享存储版本:
-
void i8_scatter_nd_s(int8_t *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, int8_t *updates, int core_mask)
-
void i16_scatter_nd_s(int16_t *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, int16_t *updates, int core_mask)
-
void i32_scatter_nd_s(int32_t *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, int32_t *updates, int core_mask)
-
void hp_scatter_nd_s(half *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, half *updates, int core_mask)
-
void fp_scatter_nd_s(float *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, float *updates, int core_mask)
-
void dp_scatter_nd_s(double *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, double *updates, int core_mask)
-
void c64_scatter_nd_s(float *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, float *updates, int core_mask)
-
void c128_scatter_nd_s(double *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, double *updates, int core_mask)
C调用示例:
//FT78NE示例(共享存储)
#include <stdio.h>
#include "78NE/utils.h"

int main() {
    float *output = (float *)0xA0000000;  // 基础输出张量在 DDR
    float *updates = (float *)0xB0000000; // 更新值在 DDR
    int *indices = (int *)0xC0000000;     // 索引在 DDR
    int out_shape[] = {4, 4, 4};
    int ind_shape[] = {5, 2};
    int out_ndim = 3;
    int ind_ndim = 2;
    int core_mask = 0x0B;

    fp_scatter_nd_s(output, out_shape, out_ndim, indices, ind_shape, ind_ndim, updates, core_mask);
    return 0;
}
私有存储版本:
-
void i8_scatter_nd_p(int8_t *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, int8_t *updates)
-
void i16_scatter_nd_p(int16_t *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, int16_t *updates)
-
void i32_scatter_nd_p(int32_t *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, int32_t *updates)
-
void hp_scatter_nd_p(half *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, half *updates)
-
void fp_scatter_nd_p(float *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, float *updates)
-
void dp_scatter_nd_p(double *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, double *updates)
-
void c64_scatter_nd_p(float *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, float *updates)
-
void c128_scatter_nd_p(double *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, double *updates)
C调用示例:
//MT7004 示例
#include <stdio.h>

int main() {
    float *output = (float *)0x10810000;
    float *updates = (float *)0x10820000;
    int *indices = (int *)0x10830000;
    int out_shape[] = {4, 4, 4};
    int ind_shape[] = {5, 2};
    int out_ndim = 3;
    int ind_ndim = 2;

    fp_scatter_nd_p(output, out_shape, out_ndim, indices, ind_shape, ind_ndim, updates);
    return 0;
}